import dgl
import torch
import torch.nn.functional as F
import numpy
import argparse
import time
import numpy as np
import os
import logging
import random
from dataset import Dataset
from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, precision_score, confusion_matrix
from BWGNN import *
from sklearn.model_selection import train_test_split
import torch.nn as nn
# Default compute device for this script: the first CUDA GPU when available, otherwise the CPU.
# (A `global` statement at module scope is a no-op, so none is needed here.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def set_seed(seed):
    """Seed every random number generator this script draws from.

    NumPy's global RNG, PyTorch (CPU and all CUDA devices), DGL and Python's
    ``random`` module are seeded with the same value, in that order.

    :param seed: integer seed applied to every generator.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        dgl.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

# G-mean metric for binary classification.
def geometric_mean(recall_0, recall_1):
    """Return the geometric mean of the two per-class recalls.

    :param recall_0: recall of class 0 (specificity).
    :param recall_1: recall of class 1 (sensitivity).
    :return: ``sqrt(recall_0 * recall_1)`` computed with ``np.sqrt``.
    """
    product = recall_0 * recall_1
    return np.sqrt(product)

# Split node indices into positive (label 1) and negative (every other label) groups.
def pos_neg_split(indices, labels):
    """Partition ``indices`` by whether ``labels[idx] == 1``.

    Input order is preserved in both outputs. Any label other than 1 is
    treated as negative.

    :param indices: iterable of node indices (consumed in a single pass).
    :param labels: indexable collection of labels.
    :return: ``(pos_indices, neg_indices)`` as lists.
    """
    positives, negatives = [], []
    for i in indices:
        bucket = positives if labels[i] == 1 else negatives
        bucket.append(i)
    return positives, negatives

def nt_xent_loss(z_i, z_j, temperature=0.1):
    """
    NT-Xent loss (normalised temperature-scaled cross entropy), as in SimCLR.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form the only positive pair.
    Every other row of either view is a negative for that anchor. The loss is
    averaged over all ``2N`` anchors.

    :param z_i: Tensor (N, d), representations of the first view.
    :param z_j: Tensor (N, d), representations of the second view.
    :param temperature: Float, temperature scaling factor for the similarities.
    :return: scalar loss tensor.
    """
    device = z_i.device
    n = z_i.size(0)

    # Rows are unit vectors after normalisation, so a matrix product gives cosine
    # similarity using O(N^2) memory instead of an O(N^2 * d) broadcast.
    representations = torch.cat([F.normalize(z_i, dim=-1), F.normalize(z_j, dim=-1)], dim=0)
    logits = representations @ representations.t() / temperature

    # An anchor is never compared with itself: exp(-inf) = 0 removes it from the softmax.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # The positive of anchor k is its counterpart in the other view.
    targets = torch.cat([torch.arange(n, 2 * n, device=device), torch.arange(n, device=device)])

    # cross_entropy works in log space (log-sum-exp), so small temperatures cannot
    # overflow exp() the way exp(sim / T) followed by log(a / b) does.
    return F.cross_entropy(logits, targets)

def generate_contrastive_pairs(batch_nodes, labels, candidate_nodes=None):
    """
    Sample a positive (same-label) and a negative (different-label) partner per node.

    For every node in ``batch_nodes``, the function draws uniformly at random, using
    NumPy's global RNG:
    - one partner with the same label, excluding the node itself;
    - one partner with a different label.

    A node gets no positive pair when it is the only member of its label in the
    pool. It gets no negative pair when the pool contains no other label.

    Per-label index arrays are built once and cached. Each lookup is a binary
    search, instead of rescanning every label for every batch node.

    NOTE: by default the partner pool is *every* node covered by ``labels``. If
    ``labels`` also covers validation/test nodes or an "unlabeled" sentinel class,
    partners can be drawn from those. Pass ``candidate_nodes`` to restrict the pool.

    :param batch_nodes: node indices to pair (sequence, NumPy array or tensor on any device).
    :param labels: per-node labels (sequence, NumPy array or tensor on any device).
    :param candidate_nodes: optional node indices partners may be drawn from;
        defaults to all nodes.
    :return: ``(positive_pairs, negative_pairs)``, each a list of ``(node, partner)``
        tuples of Python ints.
    """
    def _to_numpy(values):
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    labels_np = _to_numpy(labels)
    if candidate_nodes is None:
        pool = np.arange(len(labels_np))
    else:
        pool = np.unique(_to_numpy(candidate_nodes).astype(np.int64))  # sorted, de-duplicated
    pool_labels = labels_np[pool]

    # label value -> (sorted same-label pool nodes, sorted other-label pool nodes)
    groups = {}

    def _groups_for(label):
        if label not in groups:
            same = pool_labels == label
            groups[label] = (pool[same], pool[~same])
        return groups[label]

    positive_pairs = []
    negative_pairs = []
    for node in _to_numpy(batch_nodes).reshape(-1):
        node = int(node)
        same, others = _groups_for(labels_np[node].item())

        # Exclude the node itself from its positives. ``same`` is sorted, so find its
        # slot by binary search and skip over that slot when sampling.
        slot = int(np.searchsorted(same, node))
        contains_node = slot < len(same) and same[slot] == node
        num_same = len(same) - int(contains_node)
        if num_same > 0:
            pick = np.random.randint(num_same)
            if contains_node and pick >= slot:
                pick += 1
            positive_pairs.append((node, int(same[pick])))

        if len(others) > 0:
            negative_pairs.append((node, int(others[np.random.randint(len(others))])))

    return positive_pairs, negative_pairs
# Gradient-aware focal loss: cross-entropy weighted by focal, class-balance and gradient-magnitude factors.
class GradientAwareFocalLoss(nn.Module):
    """Cross-entropy re-weighted by three multiplicative factors.

    * Focal weight ``(1 - p_t) ** gamma_focal``.
    * Class weight. An exponential moving average (momentum 0.9) tracks how often
      each class occurs among the ``k_percent`` % highest-loss samples. The weight
      is ``(1 / count) ** (1 - gamma_ga)``, rescaled to sum to ``C``.
    * Gradient weight ``(||p - onehot(y)||_2 + 1e-8) ** gamma_grad``. Here
      ``p - onehot(y)`` is the gradient of cross-entropy w.r.t. the logits.

    The combined weight vectors are each rescaled to mean 1 before use.

    Note: every ``forward`` call updates the ``class_counts`` and ``class_weights``
    buffers, including calls made during evaluation.
    """

    def __init__(self, num_classes, k_percent=10, gamma_focal=2.0, gamma_ga=0.5, gamma_grad=1.0, use_softmax=True):
        super(GradientAwareFocalLoss, self).__init__()
        self.num_classes = num_classes
        self.k_percent = k_percent
        self.gamma_focal = gamma_focal
        self.gamma_ga = gamma_ga
        self.gamma_grad = gamma_grad  # strength of the gradient-magnitude weight
        self.use_softmax = use_softmax
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_weights', torch.ones(num_classes))

    @staticmethod
    def _flatten_classes_last(x):
        """Move the class axis (dim 1) last and flatten to (num_samples, C)."""
        C = x.shape[1]
        return x.permute(0, *range(2, x.dim()), 1).reshape(-1, C)

    @staticmethod
    def _normalize_by_mean(weights):
        """Rescale ``weights`` to mean 1.

        An all-zero vector is returned unchanged, instead of becoming 0/0 = NaN.
        """
        mean = weights.mean()
        safe_mean = torch.where(mean > 0, mean, torch.ones_like(mean))
        return weights / safe_mean

    def forward(self, inputs, targets):
        """Compute the weighted loss.

        Args:
            inputs: tensor (B, C, ...). Holds logits when ``use_softmax`` is True,
                otherwise class probabilities.
            targets: integer class indices with B * prod(...) elements.

        Returns:
            Scalar loss tensor.
        """
        B, C = inputs.shape[:2]
        N = inputs.shape[2:].numel() * B  # total number of samples

        # 1. Per-sample target-class probability and cross-entropy.
        probs = F.softmax(inputs, dim=1) if self.use_softmax else inputs
        probs = self._flatten_classes_last(probs)
        targets = targets.reshape(-1)
        pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        ce_loss = -torch.log(pt + 1e-8)

        with torch.no_grad():
            # 2. Gradient magnitude. d CE / d logits = p - onehot(y) in closed form. This
            #    needs no autograd pass, so it also works under torch.no_grad(). It also
            #    avoids feeding probabilities into F.cross_entropy, which expects logits.
            one_hot = F.one_hot(targets, num_classes=C).to(probs.dtype)
            grad_magnitude = (probs - one_hot).norm(p=2, dim=1)
            grad_weight = (grad_magnitude + 1e-8) ** self.gamma_grad  # epsilon keeps zero-gradient samples weighted

            # 3. Dynamic class balancing from the classes of the hardest k% samples.
            num_topk = max(1, int(self.k_percent / 100 * N))
            _, topk_indices = torch.topk(ce_loss, num_topk, sorted=False)
            current_counts = torch.bincount(targets[topk_indices], minlength=self.num_classes).float()
            self.class_counts = 0.9 * self.class_counts + 0.1 * current_counts
            effective_counts = self.class_counts + 1e-8
            self.class_weights = (1.0 / effective_counts) ** (1.0 - self.gamma_ga)
            self.class_weights = self.class_weights / self.class_weights.sum() * C

        # 4. Couple the weights: class-aware difficulty first, then focal hardness.
        focal_weight = (1 - pt) ** self.gamma_focal
        class_weight = self.class_weights[targets]
        difficulty_weight = self._normalize_by_mean(class_weight * grad_weight)
        # When every sample has p_t == 1, all focal weights are 0; the safe
        # normalisation keeps the loss at 0 instead of NaN.
        final_weight = self._normalize_by_mean(focal_weight * difficulty_weight)

        # 5. Weighted mean cross-entropy.
        return (final_weight * ce_loss).mean()

# Adaptive logit-perturbation (LPL) loss
class LPLLoss_advanced(nn.Module):
    def __init__(self, num_classes=2, pgd_nums=50, alpha=0.1, min_class_factor=3.0):
        """
        Adaptive logit-perturbation loss.

        Args:
            num_classes: number of classes. ``compute_adaptive_params`` takes the
                majority class as ``1 - minority_idx``, so only binary problems are
                handled correctly.
            pgd_nums: base number of perturbation steps.
            alpha: base perturbation strength.
            min_class_factor: the minority class's perturbation strength is raised to
                at least this multiple of the majority class's.
        """
        super().__init__()
        self.num_classes = num_classes
        self.pgd_nums = pgd_nums
        self.alpha = alpha
        self.min_class_factor = min_class_factor
        self.criterion = nn.CrossEntropyLoss()

        # Running (EMA) class counts and mean per-class loss.
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_grad_mags', torch.zeros(num_classes))
        self.momentum = 0.9  # EMA momentum

    def update_statistics(self, logit, y):
        """Update the EMA of per-class sample counts and of the mean per-class CE loss.

        The per-class loss serves as a proxy for gradient magnitude.
        """
        with torch.no_grad():
            batch_counts = torch.bincount(y, minlength=self.num_classes).float()
            self.class_counts = self.momentum * self.class_counts + (1 - self.momentum) * batch_counts

            grad_mags = torch.zeros(self.num_classes, device=logit.device)
            for c in range(self.num_classes):
                class_mask = (y == c)
                n_samples = torch.sum(class_mask)

                if n_samples > 0:
                    class_logits = logit[class_mask]
                    class_targets = y[class_mask]
                    # Mean CE of this class's samples, used as the gradient-magnitude estimate.
                    ce_loss = F.cross_entropy(class_logits, class_targets, reduction='none')
                    grad_mags[c] = ce_loss.mean().item()

            self.class_grad_mags = self.momentum * self.class_grad_mags + (1 - self.momentum) * grad_mags

    def compute_adaptive_params(self, logit, y):
        """Compute per-sample perturbation step counts and strengths.

        Side effect: calls ``update_statistics``, which updates the EMA buffers.

        Returns:
            sample_steps: LongTensor (N,), perturbation steps per sample.
            sample_alphas: FloatTensor (N,), perturbation strength per sample.
        """
        with torch.no_grad():
            self.update_statistics(logit, y)

            total_samples = torch.sum(self.class_counts)
            class_ratios = self.class_counts / (total_samples + 1e-8)

            minority_idx = torch.argmin(class_ratios).item()
            majority_idx = 1 - minority_idx  # binary assumption

            imbalance_ratio = class_ratios[majority_idx] / (class_ratios[minority_idx] + 1e-8)
            imbalance_factor = torch.clamp(imbalance_ratio, 1.0, 10.0)

            # Classes with a larger (EMA) loss get stronger perturbation.
            grad_scale = F.softmax(self.class_grad_mags, dim=0)

            class_steps = torch.zeros(self.num_classes, device=logit.device, dtype=torch.long)
            class_alphas = torch.zeros(self.num_classes, device=logit.device, dtype=torch.float)

            max_steps = int(self.pgd_nums * 2.0)
            min_steps = max(1, int(self.pgd_nums * 0.5))

            for c in range(self.num_classes):
                # Rarer classes get more steps.
                freq_factor = torch.sqrt(1.0 / (class_ratios[c] + 1e-8))
                steps = min_steps + int((max_steps - min_steps) * freq_factor / (freq_factor + 1.0))
                class_steps[c] = steps

                alpha_base = self.alpha * (1.0 + grad_scale[c].item() * 2.0)
                if c == minority_idx:
                    # Extra boost for the minority class, capped at 5x.
                    alpha = alpha_base * min(5.0, imbalance_factor.item() ** 0.5)
                else:
                    alpha = alpha_base
                class_alphas[c] = alpha

            # The minority class gets at least 1.5x the majority's steps ...
            if class_steps[minority_idx] < class_steps[majority_idx] * 1.5:
                class_steps[minority_idx] = int(class_steps[majority_idx] * 1.5)

            # ... and at least min_class_factor times its strength.
            if class_alphas[minority_idx] < class_alphas[majority_idx] * self.min_class_factor:
                class_alphas[minority_idx] = class_alphas[majority_idx] * self.min_class_factor

            sample_steps = torch.zeros_like(y, dtype=torch.long)
            sample_alphas = torch.zeros_like(y, dtype=torch.float)
            for c in range(self.num_classes):
                class_mask = (y == c)
                sample_steps[class_mask] = class_steps[c]
                sample_alphas[class_mask] = class_alphas[c]

            # Sample-level difficulty from the norm of d CE / d logit.
            with torch.enable_grad():
                logit_grad = logit.detach().clone().requires_grad_(True)
                loss = F.cross_entropy(logit_grad, y, reduction='none')
                grads = torch.autograd.grad(
                    outputs=loss.sum(),
                    inputs=logit_grad,
                    create_graph=False,
                    retain_graph=False
                )[0]

                sample_grad_norms = torch.norm(grads, p=2, dim=1)
                sample_difficulties = F.softmax(sample_grad_norms, dim=0)

                # Strength scale in [0.8, 1.5] relative to the hardest sample.
                difficulty_scales = 0.8 + 0.7 * sample_difficulties / (torch.max(sample_difficulties) + 1e-8)
                sample_alphas = sample_alphas * difficulty_scales

                # Step scale in [1.0, 1.5].
                steps_difficulty_scales = 1.0 + 0.5 * sample_difficulties / (torch.max(sample_difficulties) + 1e-8)
                sample_steps = (sample_steps.float() * steps_difficulty_scales).long()

            return sample_steps, sample_alphas

    def compute_adv_sign(self, logit, y, sample_alphas):
        """Compute one perturbation step for every sample.

        Returns:
            adv_logit: (N, C) perturbation to add to ``logit``.
            sub: (C,) per-class ``threshold - mean target probability``. Only classes
                present in ``y`` contribute to the threshold.
        """
        with torch.no_grad():
            logit_softmax = F.softmax(logit, dim=-1)
            y_onehot = F.one_hot(y, num_classes=self.num_classes)

            # Per-class mean softmax vector.
            sum_class_logit = torch.matmul(
                y_onehot.permute(1, 0)*1.0, logit_softmax)
            sum_class_num = torch.sum(y_onehot, dim=0)

            # Record which classes occur BEFORE replacing zero counts. Otherwise the
            # mask below would always be True and absent classes (mean prob 0) would
            # drag the threshold down.
            class_present = sum_class_num > 0
            sum_class_num = torch.where(class_present, sum_class_num, torch.full_like(sum_class_num, 100))
            mean_class_logit = torch.div(sum_class_logit, sum_class_num.reshape(-1, 1))

            # Perturbation direction per class.
            # NOTE(review): row i is divided by the norm of *column* i (dim=0); for
            # binary problems both column norms are equal — confirm intent for C > 2.
            grad = mean_class_logit - torch.eye(self.num_classes, device=logit.device)
            grad = torch.div(grad, torch.norm(grad, p=2, dim=0).reshape(-1, 1) + 1e-8)

            # Classes above the mean target probability get sign -1 (easier), and
            # classes below it get +1 (harder).
            mean_class_p = torch.diag(mean_class_logit)
            mean_class_thr = torch.mean(mean_class_p[class_present])
            sub = mean_class_thr - mean_class_p
            sign = sub.sign()

            alphas_expanded = sample_alphas.unsqueeze(1).expand(-1, self.num_classes)
            adv_logit = torch.index_select(grad, 0, y) * alphas_expanded * sign[y].unsqueeze(1)

            return adv_logit, sub

    def compute_eta(self, logit, y):
        """Compute the per-sample logit perturbation ``eta``.

        All samples are perturbed together for ``max(sample_steps)`` iterations, with
        the direction re-estimated from the current logits at each step. Sample ``i``
        then keeps the logits it had after ``sample_steps[i]`` steps.

        Returns:
            (eta, sample_steps, sample_alphas), with ``eta`` shaped like ``logit``.
        """
        with torch.no_grad():
            sample_steps, sample_alphas = self.compute_adaptive_params(logit, y)

            logit_clone = logit.clone()
            max_steps = torch.max(sample_steps).item()

            # logit_steps[s] holds every sample's logits after s steps.
            logit_steps = torch.zeros(
                [max_steps + 1, logit.shape[0], self.num_classes],
                device=logit.device, dtype=logit.dtype)

            current_logit = logit.clone()
            logit_steps[0] = current_logit
            for i in range(1, max_steps + 1):
                adv_logit, _ = self.compute_adv_sign(current_logit, y, sample_alphas)
                current_logit = current_logit + adv_logit
                logit_steps[i] = current_logit

            # Select logit_steps[sample_steps[i], i] for all samples in one gather. This
            # avoids a per-sample Python loop with one .item() sync per sample.
            sample_index = torch.arange(logit.shape[0], device=logit.device)
            logit_news = logit_steps[sample_steps, sample_index]

            eta = logit_news - logit_clone

            return eta, sample_steps, sample_alphas

    def forward(self, models_or_logits, x=None, y=None, is_logits=False):
        """Return ``(loss_adv, logit, logit_news, sample_steps, sample_alphas)``.

        ``models_or_logits`` is either precomputed logits (``is_logits=True``) or a
        callable applied to ``x``. The perturbation is computed without gradient, so
        ``loss_adv`` back-propagates through ``logit`` only.
        """
        if is_logits:
            logit = models_or_logits
        else:
            logit = models_or_logits(x)

        eta, sample_steps, sample_alphas = self.compute_eta(logit, y)
        logit_news = logit + eta
        loss_adv = self.criterion(logit_news, y)

        return loss_adv, logit, logit_news, sample_steps, sample_alphas

# ---- Clustering helpers ----
def initialize_centroids(features, k):
    """Pick ``k`` initial cluster centres with k-means++ seeding.

    The first centre is a uniformly random row of ``features``. Each further centre
    is sampled with probability proportional to the *squared* distance to the
    nearest centre chosen so far. If every point already coincides with a centre,
    sampling falls back to uniform.

    Args:
        features: (num_nodes, feature_dim) tensor.
        k: number of centres.

    Returns:
        (k, feature_dim) tensor on ``features.device`` with ``features.dtype``.
    """
    num_nodes = features.size(0)
    centroids = torch.zeros(k, features.size(1), device=features.device, dtype=features.dtype)

    first_id = torch.randint(num_nodes, (1,)).item()
    centroids[0] = features[first_id]

    for i in range(1, k):
        nearest = torch.cdist(features, centroids[:i]).min(dim=1).values
        weights = nearest ** 2  # k-means++ uses D(x)^2
        total = weights.sum()
        if total > 0:
            probabilities = weights / total
        else:
            # All distances are zero; multinomial rejects an all-zero weight vector.
            probabilities = torch.full_like(weights, 1.0 / num_nodes)
        next_id = torch.multinomial(probabilities, 1).item()
        centroids[i] = features[next_id]

    return centroids

def check_convergence(centroids, prev_centroids, tol=1e-4):
    """Report whether clustering has converged.

    Returns a boolean tensor that is True when the overall (Frobenius/L2) norm
    of ``centroids - prev_centroids`` is below ``tol``.
    """
    shift = (centroids - prev_centroids).norm()
    return shift < tol

def robust_node_clustering(features, k=2, temperature=0.1, max_iterations=10, labeled_features=None, labeled_classes=None):
    """Soft-cluster node features into ``k`` clusters (Gumbel-softmax soft k-means).

    Centroid selection runs under ``torch.no_grad()``, in one of two modes.

    * Labeled mode: both ``labeled_features`` and ``labeled_classes`` are given.
      Centroid ``i`` is the mean of the labeled features of class ``i``, or a random
      unit vector when class ``i`` has no labeled sample. All centroids are then
      scaled to unit L2 norm, and no refinement iterations are run.
    * Unlabeled mode: centroids are seeded with ``initialize_centroids`` and refined
      for at most ``max_iterations`` rounds. Each round draws soft assignments with
      ``F.gumbel_softmax(-dist / temperature, tau=temperature)`` and sets each
      centroid to the assignment-weighted mean of ``features``. Refinement stops
      early once ``check_convergence`` reports a shift below 1e-4.

    The returned assignments are recomputed the same way *outside* ``no_grad``.
    Gradients therefore flow back into ``features``, while the centroids act as
    constants. The assignments are stochastic (Gumbel noise), so seed the torch RNG
    for reproducibility.

    NOTE(review): in labeled mode the centroids are unit-norm but ``features`` are
    used as-is; confirm this scale mismatch is intended.

    Args:
        features: node features, ``[num_nodes, feature_dim]``.
        k: number of clusters (default 2, i.e. binary).
        temperature: divides the negative distances and is also the Gumbel-softmax
            ``tau``; smaller values give harder assignments.
        max_iterations: maximum refinement rounds (unlabeled mode only).
        labeled_features: optional labeled features, ``[num_labeled, feature_dim]``.
        labeled_classes: optional labels for ``labeled_features``, ``[num_labeled]``.

    Returns:
        tuple ``(original, view1, view2, centroids)``:
        - ``original`` is the ``[num_nodes, k]`` soft assignment of ``features``.
        - ``view1`` and ``view2`` are the *same tensor object* as ``original``. No
          separate views are clustered; they are kept for interface compatibility.
        - ``centroids`` is ``[k, feature_dim]``.
    """
    feature_dim = features.size(1)
    device = features.device
    use_labeled_centroids = labeled_features is not None and labeled_classes is not None

    with torch.no_grad():
        if use_labeled_centroids:
            centroids = torch.zeros(k, feature_dim, device=device)
            for i in range(k):
                class_indices = torch.where(labeled_classes == i)[0]
                if len(class_indices) > 0:
                    centroids[i] = labeled_features[class_indices].mean(dim=0)
                else:
                    # No labeled sample for this class: random unit vector.
                    centroids[i] = torch.randn(feature_dim, device=device)
                    centroids[i] = F.normalize(centroids[i], p=2, dim=0)

            # Give all centroids unit norm (epsilon avoids division by zero).
            norms = torch.norm(centroids, dim=1, keepdim=True)
            centroids = centroids / (norms + 1e-10)
        else:
            centroids = initialize_centroids(features, k)
            prev_centroids = centroids.clone()

            for _ in range(max_iterations):
                distances = torch.cdist(features, centroids)  # [num_nodes, k]
                logits = -distances / temperature
                cluster_assignments = F.gumbel_softmax(logits, tau=temperature, hard=False)

                new_centroids = torch.zeros_like(centroids)
                for j in range(k):
                    weights = cluster_assignments[:, j].unsqueeze(1)  # [num_nodes, 1]
                    if weights.sum() > 0:
                        new_centroids[j] = (features * weights).sum(0) / weights.sum()
                    else:
                        new_centroids[j] = centroids[j].clone()  # keep an empty cluster's centre

                centroids = new_centroids
                if check_convergence(centroids, prev_centroids, tol=1e-4):
                    break
                prev_centroids = centroids.clone()

    # Final assignments with autograd enabled, so gradients reach `features`.
    distances_original = torch.cdist(features, centroids)  # [num_nodes, k]
    logits_original = -distances_original / temperature
    original_cluster_assignments = F.gumbel_softmax(logits_original, tau=temperature, hard=False)

    return original_cluster_assignments, original_cluster_assignments, original_cluster_assignments, centroids

def compute_clustering_loss(features, cluster_assignments, centroids):
    """Soft-assignment-weighted mean distance of samples to the cluster centres.

    Also reports the hard cluster sizes, treating the larger cluster as the
    negative (majority) class.

    Args:
        features: sample features, ``[num_samples, feature_dim]``.
        cluster_assignments: soft assignments, ``[num_samples, k]``.
        centroids: cluster centres, ``[k, feature_dim]``. Only k == 2 is supported.

    Returns:
        ``(loss, num_pos, num_neg)``:
        - ``loss`` is ``sum(dist * assignments) / num_samples``, where ``dist`` is the
          Euclidean distance of each sample to each centre.
        - ``num_pos`` and ``num_neg`` are the sizes of the smaller and larger hard
          (argmax) cluster. On a tie, cluster 0 counts as the majority.

    Raises:
        ValueError: if ``centroids`` does not hold exactly two centres. (An
        ``assert`` would be stripped under ``python -O``.)
    """
    k = centroids.shape[0]
    if k != 2:
        raise ValueError("目前只支持二分类问题（k=2）")

    # Hard cluster sizes; the larger cluster is assumed to be the negatives.
    hard_assignments = torch.argmax(cluster_assignments, dim=1)
    cluster_0_count = torch.sum(hard_assignments == 0).item()
    cluster_1_count = torch.sum(hard_assignments == 1).item()
    num_neg = max(cluster_0_count, cluster_1_count)
    num_pos = min(cluster_0_count, cluster_1_count)

    # Column i holds every sample's distance to centre i. Stacking keeps the dtype of
    # `features`, rather than the default dtype that torch.zeros would impose.
    distances = torch.stack(
        [torch.norm(features - centroid.unsqueeze(0), dim=1) for centroid in centroids],
        dim=1,
    )

    # Pull samples towards the centres they are softly assigned to.
    loss = torch.sum(distances * cluster_assignments) / features.shape[0]

    return loss, num_pos, num_neg

def train(model, g, args):
    # 设置日志
    log_dir = "logs"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    # 添加时间戳到日志文件名
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    log_file = os.path.join(log_dir, f"train_log_{args.dataset}_seed{args.seed}_{timestamp}.txt")
    
    logging.basicConfig(filename=log_file, level=logging.INFO, 
                        format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    
    # 记录运行开始时间和日志文件位置
    logging.info(f"运行开始时间: {timestamp}")
    logging.info(f"日志文件: {log_file}")
    
    features = g.ndata['feature']
    labels = g.ndata['label']
    index = list(range(len(labels)))
    
    # 对FFSD数据集的特殊处理：只使用标签为0和1的样本
    if args.dataset == 'amazon':
        index = list(range(3305, len(labels)))
    elif args.dataset == 'ffsd' or args.dataset == 'S-FFSD':
        # 使用所有样本，不过滤标签为2的无标签样本
        index = list(range(len(labels)))
        print(f"FFSD数据集：使用所有样本，数量 = {len(index)}")
        print(f"标签0（负样本）数量：{(labels == 0).sum().item()}")
        print(f"标签1（正样本）数量：{(labels == 1).sum().item()}")
        print(f"标签2（无标签样本）数量：{(labels == 2).sum().item()}")
    
    
    # 常规划分验证集和测试集
    idx_train, idx_rest, y_train, y_rest = train_test_split(index, labels[index], stratify=labels[index],
                                                            train_size=args.train_ratio,
                                                            random_state=args.seed, shuffle=True)
    idx_valid, idx_test, y_valid, y_test = train_test_split(idx_rest, y_rest, stratify=y_rest,
                                                            test_size=0.67,
                                                            random_state=args.seed, shuffle=True)
    
    # # 修改训练集，只保留一个正样本和一个负样本
    # pos_samples = [i for i in idx_train if labels[i] == 1]
    # neg_samples = [i for i in idx_train if labels[i] == 0]
    
    # # 如果正样本或负样本数量不足，记录警告
    # if len(pos_samples) == 0:
    #     logging.warning("训练集中没有正样本，无法选择一个正样本")
    # if len(neg_samples) == 0:
    #     logging.warning("训练集中没有负样本，无法选择一个负样本")
    
    # # 选择一个正样本和一个负样本作为有标签数据
    # selected_pos = [pos_samples[0]] if len(pos_samples) > 0 else []
    # selected_neg = [neg_samples[0]] if len(neg_samples) > 0 else []
    
    # 有标签训练集只包含一个正样本和一个负样本
    # 标签不为2的样本作为有标签数据
    idx_labeled = [i for i in idx_train if labels[i] != 2]
    
    # 其余训练样本作为无标签数据
    idx_unlabeled = [i for i in idx_train if i not in idx_labeled]
    
    train_mask = torch.zeros([len(labels)]).bool()
    val_mask = torch.zeros([len(labels)]).bool()
    test_mask = torch.zeros([len(labels)]).bool()
    
    # 有标签数据对应的掩码
    labeled_mask = torch.zeros([len(labels)]).bool()
    # 无标签数据对应的掩码
    unlabeled_mask = torch.zeros([len(labels)]).bool()

    train_mask[idx_train] = 1
    val_mask[idx_valid] = 1
    test_mask[idx_test] = 1
    labeled_mask[idx_labeled] = 1
    unlabeled_mask[idx_unlabeled] = 1
    
    # 评估时只考虑标签为0和1的样本
    if args.dataset == 'ffsd' or args.dataset == 'S-FFSD':
        eval_mask = (labels < 2)
        # 确保验证集和测试集只包含标签为0和1的样本
        val_mask = val_mask & eval_mask
        test_mask = test_mask & eval_mask
        # 输出统计信息
        print(f"评估时只考虑标签为0和1的样本")
        print(f"验证集：标签0={torch.sum((labels == 0) & val_mask).item()}, 标签1={torch.sum((labels == 1) & val_mask).item()}, 总计={val_mask.sum().item()}")
        print(f"测试集：标签0={torch.sum((labels == 0) & test_mask).item()}, 标签1={torch.sum((labels == 1) & test_mask).item()}, 总计={test_mask.sum().item()}")
    
    # 将数据移到适当的设备上（CPU或GPU）
    device = features.device
    train_mask = train_mask.to(device)
    val_mask = val_mask.to(device)
    test_mask = test_mask.to(device)
    labeled_mask = labeled_mask.to(device)
    unlabeled_mask = unlabeled_mask.to(device)
    if isinstance(labels, numpy.ndarray):
        labels = torch.tensor(labels).to(device)
    
    # logging.info('有标签数据: 正样本数: %d, 负样本数: %d, 总数: %d', len(selected_pos), len(selected_neg), len(idx_labeled))
    logging.info('无标签数据数量: %d', len(idx_unlabeled))
    logging.info('训练集总数: %d, 验证集: %d, 测试集: %d', train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item())
    
    # 初始化优化器
    optimizer = torch.optim.Adam(
        list(model.parameters()), 
        lr=0.001, #0.003 # 0.001
        weight_decay=1e-4 #3e-5 #5e-6
    )
    best_f1, final_tf1, final_trec, final_tpre, final_tmf1, final_tauc = 0., 0., 0., 0., 0., 0.
    best_acc1, best_acc0, best_gmean = 0., 0., 0.

    # 添加变量跟踪每个epoch的最佳测试结果
    best_test_auc = 0.0
    best_test_f1 = 0.0
    best_test_rec = 0.0
    best_test_pre = 0.0
    best_test_acc1 = 0.0
    best_test_acc0 = 0.0
    best_test_gmean = 0.0
    best_test_epoch = 0

    # 计算正负样本比例权重
    pos_weight = labels[labeled_mask].sum().item()
    neg_weight = (1 - labels[labeled_mask]).sum().item()
    weight = neg_weight / pos_weight if pos_weight > 0 else 1.0
    logging.info('cross entropy weight: %f (不使用)', weight)
    
    # 初始化损失函数
    gradient_aware_focal = GradientAwareFocalLoss(num_classes=2, 
                                                k_percent=20, 
                                                gamma_focal=2, 
                                                gamma_ga=0.5, 
                                                gamma_grad=1, 
                                                use_softmax=True).to(device)
    
    # 初始化自适应LPL损失函数
    adaptive_lpl_loss = LPLLoss_advanced(
        num_classes=2,
        pgd_nums=30,
        alpha=0.05,
        min_class_factor=3.5
    ).to(device)
    
    # 控制是否使用原始模型输出和聚类结果的开关
    use_original_pseudo_labels = True
    use_clustering_pseudo_labels = True
    
    # 聚类和伪标签相关配置
    clustering_temperature = 0.1  # 控制聚类软分配的软硬程度
    fixed_cluster_epochs = 50  # 前几个epoch使用有标签数据固定聚类中心
    
    # 记录相关信息
    logging.info('伪标签策略：')
    if use_original_pseudo_labels and use_clustering_pseudo_labels:
        logging.info('  使用模型输出和聚类结果的融合')
    elif use_original_pseudo_labels:
        logging.info('  仅使用模型输出')
    elif use_clustering_pseudo_labels:
        logging.info('  仅使用聚类结果')
    else:
        logging.info('  未启用伪标签')
    
    if not hasattr(args, 'mu_rampup'):
        args.mu_rampup = True  # 默认启用rampup
    if not hasattr(args, 'consistency_rampup'):
        args.consistency_rampup = None  # 默认使用总epoch数
    if not hasattr(args, 'mu'):
        args.mu = 1.5
    
    def get_current_mu(epoch, args):
        if args.mu_rampup:
            # Consistency ramp-up from https://arxiv.org/abs/1610.02242
            if args.consistency_rampup is None:
                #args.consistency_rampup = args.num_epochs
                args.consistency_rampup = 500
            return args.mu * sigmoid_rampup(epoch, args.consistency_rampup)
        else:
            return args.mu
    
    def sigmoid_rampup(current, rampup_length):
        '''Exponential rampup from https://arxiv.org/abs/1610.02242'''
        if rampup_length == 0:
            return 1.0
        else:
            current = np.clip(current, 0.0, rampup_length)
            phase = 1.0 - current / rampup_length
            return float(np.exp(-5.0 * phase * phase))    
    

    # 训练循环
    time_start = time.time()
    for e in range(args.epoch):
        # 获取当前epoch的mu值
        current_mu = get_current_mu(e, args)
        model.train()
        out_original, h_original = model(features, 0)  # out是features, h是log_probs
        out1, h1 = model(features, 1)
        out2, h2 = model(features, 2)
        
        # 只对有标签数据生成正负样本对
        labeled_indices = torch.nonzero(labeled_mask).squeeze(-1).cpu().numpy()
        labels_np = labels.cpu().numpy()

        positive_pairs, negative_pairs = generate_contrastive_pairs(labeled_indices, labels_np)
  
        
        # 计算对比损失
        if len(positive_pairs) > 0:
            # 视图1的对比损失
            z_i_1 = h1[torch.tensor([p[0] for p in positive_pairs], dtype=torch.long).to(device)]
            z_j_1 = h1[torch.tensor([p[1] for p in positive_pairs], dtype=torch.long).to(device)]
            contrastive_loss_1 = nt_xent_loss(z_i_1, z_j_1)
            
            # 视图2的对比损失
            z_i_2 = h2[torch.tensor([p[0] for p in positive_pairs], dtype=torch.long).to(device)]
            z_j_2 = h2[torch.tensor([p[1] for p in positive_pairs], dtype=torch.long).to(device)]
            contrastive_loss_2 = nt_xent_loss(z_i_2, z_j_2)
            
            contrastive_loss = (contrastive_loss_1 + contrastive_loss_2) / 2
        else:
            contrastive_loss = torch.tensor(0.0).to(device)
        
        # 有标签数据的损失计算
        # 分类损失（nll_loss）- 只使用两个增强视图的损失，与参考文件一致
        classification_loss_1 = F.nll_loss(out1[labeled_mask], labels[labeled_mask])
        classification_loss_2 = F.nll_loss(out2[labeled_mask], labels[labeled_mask])
        
        # 一致性损失 - 只使用两个增强视图之间的一致性，与参考文件一致
        consistency_loss_1_2 = F.mse_loss(h1, h2)  # 两个增强视图之间的特征一致性
        
        # 初始化无标签损失
        pseudo_label_loss = torch.tensor(0.0, device=device)
        pseudo_lpl_loss = torch.tensor(0.0, device=device)
        clustering_loss = torch.tensor(0.0, device=device)
        num_pos = 0
        num_neg = 0
        
        # 只有在有无标签数据的情况下才进行聚类和伪标签处理
        if unlabeled_mask.sum() > 0:
            # 提取有标签和无标签数据的特征 - 使用logits (h)进行聚类，与参考文件一致
            labeled_features = h_original[labeled_mask]
            labeled_classes = labels[labeled_mask]
            h_orig_unlabeled = h_original[unlabeled_mask]
            h1_unlabeled = h1[unlabeled_mask]
            h2_unlabeled = h2[unlabeled_mask]


            # 获取无标签节点索引
            unlabeled_indices = torch.nonzero(unlabeled_mask).squeeze(-1)
            
            # 聚类处理
            if e < fixed_cluster_epochs:
                # 前几个epoch使用有标签数据固定聚类中心
                cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2, centroids_orig = robust_node_clustering(
                    h_orig_unlabeled,
                    k=2,
                    temperature=0.8,
                    max_iterations=10,
                    labeled_features=labeled_features,
                    labeled_classes=labeled_classes
                )
            else:
                # 之后的epoch不再使用有标签数据固定中心
                cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2, centroids_orig = robust_node_clustering(
                    h_orig_unlabeled,
                    k=2,
                    temperature=0.8,
                    max_iterations=10
                )
            
            # 创建合并的特征和分配
            all_features = torch.cat([h_orig_unlabeled, h1_unlabeled, h2_unlabeled], dim=0)
            all_assignments = torch.cat([cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2], dim=0)
            
            # 计算统一的聚类损失
            clustering_loss, num_pos_all, num_neg_all = compute_clustering_loss(
                all_features, 
                all_assignments, 
                centroids_orig  # 使用原始图的聚类中心
            )
            
            # 伪标签生成逻辑
            with torch.no_grad():
                # 初始化伪标签张量
                final_pseudo_labels_for_batch_unlabeled = torch.tensor([], dtype=torch.long, device=device)
                
                if use_original_pseudo_labels and use_clustering_pseudo_labels:
                    # 融合模型输出和聚类结果
                    orig_logits_unlabeled = out_original[unlabeled_mask]  # 模型输出概率 [P(0), P(1)]
                    orig_probs_unlabeled = F.softmax(orig_logits_unlabeled, dim=1) # 模型输出概率 [P(负), P(正)]
                    
                    # 确定聚类0和聚类1哪个是多数类（负样本），哪个是少数类（正样本）
                    temp_cluster_hard_labels = torch.argmax(cluster_assignments_orig, dim=1) # 初步判断样本属于哪个聚类
                    count_c0 = torch.sum(temp_cluster_hard_labels == 0).item()
                    count_c1 = torch.sum(temp_cluster_hard_labels == 1).item()
                    
                    # 将聚类概率对齐为 [P(0), P(1)] 格式
                    aligned_cluster_probs = cluster_assignments_orig.clone()
                    if count_c0 < count_c1:
                        # 如果聚类0是少数类（正样本），交换列
                        aligned_cluster_probs[:, 0] = cluster_assignments_orig[:, 1]
                        aligned_cluster_probs[:, 1] = cluster_assignments_orig[:, 0]
                    
                    # 融合概率
                    combined_probs_unlabeled = (orig_probs_unlabeled + aligned_cluster_probs) / 2.0
                    final_pseudo_labels_for_batch_unlabeled = torch.argmax(combined_probs_unlabeled, dim=1)
                    
                elif use_clustering_pseudo_labels:
                    # 仅使用聚类结果
                    temp_cluster_hard_labels = torch.argmax(cluster_assignments_orig, dim=1) # 初步判断样本属于哪个聚类
                    count_c0 = torch.sum(temp_cluster_hard_labels == 0).item()
                    count_c1 = torch.sum(temp_cluster_hard_labels == 1).item()
                    
                    if count_c0 >= count_c1:
                        # 聚类0是多数类（标签0），聚类1是少数类（标签1）
                        final_pseudo_labels_for_batch_unlabeled = temp_cluster_hard_labels
                    else:
                        # 聚类1是多数类（标签0），聚类0是少数类（标签1）
                        final_pseudo_labels_for_batch_unlabeled = 1 - temp_cluster_hard_labels
                
                elif use_original_pseudo_labels:
                    # 仅使用模型输出
                    orig_logits_unlabeled = h_orig_unlabeled
                    final_pseudo_labels_for_batch_unlabeled = torch.argmax(orig_logits_unlabeled, dim=1)
                
                # 记录伪标签中的正负样本数量
                if len(final_pseudo_labels_for_batch_unlabeled) > 0:
                    pseudo_pos_count = torch.sum(final_pseudo_labels_for_batch_unlabeled == 1).item()
                    pseudo_neg_count = torch.sum(final_pseudo_labels_for_batch_unlabeled == 0).item()
                    logging.info('Epoch %d 伪标签: 正样本: %d (%.1f%%), 负样本: %d (%.1f%%)', 
                                e, pseudo_pos_count, pseudo_pos_count/len(final_pseudo_labels_for_batch_unlabeled)*100,
                                pseudo_neg_count, pseudo_neg_count/len(final_pseudo_labels_for_batch_unlabeled)*100)
            
            # 计算伪标签损失
            if len(final_pseudo_labels_for_batch_unlabeled) > 0:
                # 使用GradientAwareFocalLoss计算伪标签损失
                pseudo_logits_1 = out1[unlabeled_mask]
                pseudo_logits_2 = out2[unlabeled_mask]
                consistent_pseudo_labels = final_pseudo_labels_for_batch_unlabeled
                # 计算伪标签分类损失
                pseudo_label_loss_1 = gradient_aware_focal(
                                        pseudo_logits_1, 
                                        final_pseudo_labels_for_batch_unlabeled
                                    )
                pseudo_label_loss_2 = gradient_aware_focal(
                                        pseudo_logits_2, 
                                        consistent_pseudo_labels
                                    )
                
                pseudo_label_loss = (pseudo_label_loss_1 + pseudo_label_loss_2) / 3.0
                
                # 计算LPL损失（只用于原始视图的特征）
                # 视图1的自适应LPL损失
                adap_lpl_loss_1, _, _, steps_1, alphas_1 = adaptive_lpl_loss(pseudo_logits_1, None, consistent_pseudo_labels, is_logits=True)
                                    
                # 视图2的自适应LPL损失
                adap_lpl_loss_2, _, _, steps_2, alphas_2 = adaptive_lpl_loss(pseudo_logits_2, None, consistent_pseudo_labels, is_logits=True)
                                    
                # 组合自适应LPL损失
                pseudo_lpl_loss = (adap_lpl_loss_1 + adap_lpl_loss_2) / 2
        
        # 综合损失计算 - 严格按照参考文件
        labeled_loss = (classification_loss_1 + classification_loss_2) / 2
        labeled_loss1 = F.nll_loss(out_original[labeled_mask], labels[labeled_mask])
    
        consistency_loss = consistency_loss_1_2
        # 最终损失 = 有标签损失 + 一致性损失 + 无标签损失
        # 与参考文件保持一致：不使用权重动态调整
        loss = labeled_loss1


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 输出各种损失的值
        if e % 10 == 0:
            log_str = f'Epoch {e}, 有标签损失: {labeled_loss:.4f}, 一致性损失: {consistency_loss:.4f}'
            if unlabeled_mask.sum() > 0:
                log_str += f', 伪标签损失: {pseudo_label_loss:.4f}, LPL损失: {pseudo_lpl_loss:.4f}, 聚类损失: {clustering_loss:.4f}'
            log_str += f', 总损失: {loss:.4f}'
            logging.info(log_str)
        
        model.eval()
        # 使用原始视图的输出进行验证和测试
        out_original, h_original = model(features, 0)  # 重新获取原始视图的输出
        probs = torch.exp(out_original).cpu().detach()  # 将log_softmax转换回概率
        
        # 验证集评估
        f1, thres = get_best_f1(labels[val_mask], probs[val_mask])
        preds = numpy.zeros_like(labels.cpu().numpy())
        preds[probs[:, 1].detach().cpu().numpy() > thres] = 1
        
        # 计算验证集的ACC1和ACC0
        val_pos_mask = (labels[val_mask] == 1)
        val_neg_mask = (labels[val_mask] == 0)
        
        val_acc1 = accuracy_score(labels[val_mask][val_pos_mask].cpu().numpy(), preds[val_mask][val_pos_mask]) if val_pos_mask.sum() > 0 else 0
        val_acc0 = accuracy_score(labels[val_mask][val_neg_mask].cpu().numpy(), preds[val_mask][val_neg_mask]) if val_neg_mask.sum() > 0 else 0
        
        # 计算验证集的召回率
        val_rec1 = recall_score(labels[val_mask][val_pos_mask].cpu().numpy(), preds[val_mask][val_pos_mask]) if val_pos_mask.sum() > 0 else 0
        val_rec0 = 1 - (preds[val_mask][val_neg_mask].sum() / val_neg_mask.sum().item()) if val_neg_mask.sum() > 0 else 0
        
        # 计算验证集的Gmean
        val_gmean = geometric_mean(val_rec0, val_rec1)
        
        # 测试集评估
        trec = recall_score(labels[test_mask].cpu().numpy(), preds[test_mask], average='binary')
        tpre = precision_score(labels[test_mask].cpu().numpy(), preds[test_mask], average='binary')
        tmf1 = f1_score(labels[test_mask].cpu().numpy(), preds[test_mask], average='macro')
        tauc = roc_auc_score(labels[test_mask].cpu().numpy(), probs[test_mask][:, 1].detach().cpu().numpy())
        
        # 计算测试集的ACC1和ACC0
        test_pos_mask = (labels[test_mask] == 1)
        test_neg_mask = (labels[test_mask] == 0)
        
        test_acc1 = accuracy_score(labels[test_mask][test_pos_mask].cpu().numpy(), preds[test_mask][test_pos_mask]) if test_pos_mask.sum() > 0 else 0
        test_acc0 = accuracy_score(labels[test_mask][test_neg_mask].cpu().numpy(), preds[test_mask][test_neg_mask]) if test_neg_mask.sum() > 0 else 0
        
        # 计算测试集的召回率
        test_rec1 = recall_score(labels[test_mask][test_pos_mask].cpu().numpy(), preds[test_mask][test_pos_mask]) if test_pos_mask.sum() > 0 else 0
        test_rec0 = 1 - (preds[test_mask][test_neg_mask].sum() / test_neg_mask.sum().item()) if test_neg_mask.sum() > 0 else 0
        
        # 计算测试集的Gmean
        test_gmean = geometric_mean(test_rec0, test_rec1)

        # 记录每个epoch的测试结果
        logging.info('Epoch %d, 测试集 AUC: %.4f, PRE: %.4f, REC: %.4f, MF1: %.4f, ACC1: %.4f, ACC0: %.4f, Gmean: %.4f', 
                    e, tauc, tpre, trec, tmf1, test_acc1, test_acc0, test_gmean)

        # 如果当前epoch的测试AUC更好，则更新最佳测试结果
        if tauc > best_test_auc:
            best_test_auc = tauc
            best_test_f1 = tmf1
            best_test_rec = trec
            best_test_pre = tpre
            best_test_acc1 = test_acc1
            best_test_acc0 = test_acc0
            best_test_gmean = test_gmean
            best_test_epoch = e
            
            # 记录更新的最佳测试结果
            logging.info('更新最佳测试结果 (Epoch %d): AUC %.4f, PRE %.4f, REC %.4f, MF1 %.4f, ACC1 %.4f, ACC0 %.4f, Gmean %.4f', 
                        e, best_test_auc, best_test_pre, best_test_rec, best_test_f1, best_test_acc1, best_test_acc0, best_test_gmean)

        if best_f1 < f1:
            best_f1 = f1
            final_trec = trec
            final_tpre = tpre
            final_tmf1 = tmf1
            final_tauc = tauc
            best_acc1 = test_acc1
            best_acc0 = test_acc0
            best_gmean = test_gmean
            
        # 记录验证集指标
        logging.info('Epoch %d, val mf1: %.4f, val ACC1: %.4f, val ACC0: %.4f, val Gmean: %.4f (best %.4f)', 
                    e, f1, val_acc1, val_acc0, val_gmean, best_f1)

    time_end = time.time()
    logging.info('time cost: %.2f s', time_end - time_start)
    
    # 输出基于验证集选择的最佳模型的测试结果
    logging.info('基于验证集最佳模型的测试结果: REC %.2f PRE %.2f MF1 %.2f AUC %.2f ACC1 %.2f ACC0 %.2f Gmean %.2f', 
                final_trec*100, final_tpre*100, final_tmf1*100, final_tauc*100, 
                best_acc1*100, best_acc0*100, best_gmean*100)
    
    # 输出在测试集上观察到的最佳结果
    logging.info('测试集上观察到的最佳结果 (Epoch %d): AUC %.2f PRE %.2f REC %.2f MF1 %.2f ACC1 %.2f ACC0 %.2f Gmean %.2f', 
                best_test_epoch, best_test_auc*100, best_test_pre*100, best_test_rec*100, best_test_f1*100,
                best_test_acc1*100, best_test_acc0*100, best_test_gmean*100)
    
    return final_tmf1, final_tauc, best_acc1, best_acc0, best_gmean


# threshold adjusting for best macro f1
def get_best_f1(labels, probs):
    """Scan a fixed grid of decision thresholds and return the best macro-F1.

    Column 1 of *probs* is treated as the positive-class score; a sample is
    predicted positive when that score is strictly greater than the threshold.
    The 19 thresholds ``np.linspace(0.05, 0.95, 19)`` are tried in increasing
    order. Only a strictly larger macro-F1 replaces the current best, so ties
    keep the smallest threshold.

    :param labels: ground-truth 0/1 labels (torch tensor or numpy array).
    :param probs: per-class scores with at least two columns (torch tensor or
        numpy array); column 1 is used as the positive score.
    :return: ``(best_macro_f1, best_threshold)``; ``(0, 0)`` if no threshold
        yields a macro-F1 above zero.
    """
    # sklearn operates on numpy arrays, so bring any torch tensors to the CPU first.
    y_true = labels.cpu().numpy() if isinstance(labels, torch.Tensor) else labels
    scores = probs.detach().cpu().numpy() if isinstance(probs, torch.Tensor) else probs
    pos_scores = scores[:, 1]

    top_f1, top_thres = 0, 0
    for cutoff in np.linspace(0.05, 0.95, 19):
        y_pred = np.zeros_like(y_true)
        y_pred[pos_scores > cutoff] = 1
        candidate = f1_score(y_true, y_pred, average='macro')
        if candidate > top_f1:
            top_f1, top_thres = candidate, cutoff
    return top_f1, top_thres


def make_feature_masked_graph(original_graph, drop_rate=0.2):
    """Return a copy of *original_graph* with a random subset of feature columns zeroed.

    A single random mask is drawn over the feature dimension
    (``ndata['feature'].size(1)``) and applied to every node. Each feature
    column is therefore either kept for all nodes or zeroed for all nodes.
    ``original_graph`` itself is not modified.

    :param original_graph: graph exposing ``clone()`` and a 2-D
        ``ndata['feature']`` tensor.
    :param drop_rate: probability that a feature column is zeroed (default 0.2).
    :return: the cloned graph carrying the masked feature tensor.
    """
    graph = original_graph.clone()
    feat_data = graph.ndata['feature']
    # torch.rand samples [0, 1): drop_rate=1.0 zeroes every column.
    keep_mask = torch.rand(feat_data.size(1), device=feat_data.device) > drop_rate
    feat_aug = feat_data.clone()
    feat_aug[:, ~keep_mask] = 0
    graph.ndata['feature'] = feat_aug
    return graph


def make_bwgnn_model(homo, in_feats, h_feats, num_classes, graph, graph1, graph2, order):
    """Build a BWGNN (truthy *homo*) or BWGNN_Hetero (falsy *homo*) model.

    The original graph and its two augmented views are handed to the model
    constructor; *order* is passed as ``d``.
    """
    model_cls = BWGNN if homo else BWGNN_Hetero
    return model_cls(in_feats, h_feats, num_classes, initial_graph=graph,
                     augmented_graph1=graph1, augmented_graph2=graph2, d=order)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='BWGNN')
    parser.add_argument("--dataset", type=str, default="amazon",
                        help="Dataset for this model (yelp/amazon/tfinance/tsocial)")
    parser.add_argument("--train_ratio", type=float, default=0.4, help="Training ratio")
    parser.add_argument("--hid_dim", type=int, default=64, help="Hidden layer dimension")
    parser.add_argument("--order", type=int, default=2, help="Order C in Beta Wavelet")
    parser.add_argument("--homo", type=int, default=1, help="1 for BWGNN(Homo) and 0 for BWGNN(Hetero)")
    parser.add_argument("--epoch", type=int, default=100, help="The max number of epochs")
    parser.add_argument("--run", type=int, default=1, help="Running times")
    parser.add_argument("--seed", type=int, default=64, help="Random seed")

    args = parser.parse_args()
    # With --run < 1 the multi-run branch would train nothing and then report
    # NaN means/stds over empty lists; reject such values up front.
    if args.run < 1:
        parser.error("--run must be >= 1")
    print(args)

    # Seed numpy / torch / dgl / random before any random work happens.
    set_seed(args.seed)

    dataset_name = args.dataset
    homo = args.homo
    order = args.order
    h_feats = args.hid_dim
    graph = Dataset(dataset_name, homo).graph
    in_feats = graph.ndata['feature'].shape[1]
    num_classes = 2

    # Two feature-masked views (20% and 30% of feature columns dropped).
    # NOTE(review): in multi-run mode these first views are never used, but
    # drawing them advances the RNG; they are kept so seeded runs reproduce
    # the results of the previous version of this script.
    graph1 = make_feature_masked_graph(graph, 0.2)
    graph2 = make_feature_masked_graph(graph, 0.3)

    if args.run == 1:
        model = make_bwgnn_model(homo, in_feats, h_feats, num_classes,
                                 graph, graph1, graph2, order)
        train(model, graph, args)

    else:
        final_mf1s, final_aucs, final_acc1s, final_acc0s, final_gmeans = [], [], [], [], []
        for _ in range(args.run):
            # Draw fresh augmented views for every run.
            graph1 = make_feature_masked_graph(graph, 0.2)
            graph2 = make_feature_masked_graph(graph, 0.3)

            model = make_bwgnn_model(homo, in_feats, h_feats, num_classes,
                                     graph, graph1, graph2, order)
            mf1, auc, acc1, acc0, gmean = train(model, graph, args)
            final_mf1s.append(mf1)
            final_aucs.append(auc)
            final_acc1s.append(acc1)
            final_acc0s.append(acc0)
            final_gmeans.append(gmean)

        final_mf1s = np.array(final_mf1s)
        final_aucs = np.array(final_aucs)
        final_acc1s = np.array(final_acc1s)
        final_acc0s = np.array(final_acc0s)
        final_gmeans = np.array(final_gmeans)

        logging.info('MF1-mean: %.2f, MF1-std: %.2f, AUC-mean: %.2f, AUC-std: %.2f', 
                   100 * np.mean(final_mf1s), 100 * np.std(final_mf1s),
                   100 * np.mean(final_aucs), 100 * np.std(final_aucs))
        logging.info('ACC1-mean: %.2f, ACC1-std: %.2f, ACC0-mean: %.2f, ACC0-std: %.2f', 
                   100 * np.mean(final_acc1s), 100 * np.std(final_acc1s),
                   100 * np.mean(final_acc0s), 100 * np.std(final_acc0s))
        logging.info('Gmean-mean: %.2f, Gmean-std: %.2f', 
                   100 * np.mean(final_gmeans), 100 * np.std(final_gmeans))
